#!/usr/bin/env python
# coding: utf-8

# # Creating a Zero-Shot classifier based on BETO
# 
# This notebook/script fine-tunes a BETO model (a Spanish BERT, `dccuchile/bert-base-spanish-wwm-cased`) on the Spanish portion of the XNLI dataset.
# The fine-tuned model can then be loaded into a Hugging Face zero-shot-classification pipeline to obtain a zero-shot classifier.

# In[ ]:


from datasets import load_dataset, Dataset, load_metric, load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
import torch
from pathlib import Path
# from ray import tune
# from ray.tune.suggest.hyperopt import HyperOptSearch
# from ray.tune.schedulers import ASHAScheduler


# # Prepare the datasets

# In[ ]:


# Load the Spanish ("es") configuration of XNLI; the result is indexed by
# split name ("train", "validation", ...).
xnli_es = load_dataset(path="xnli", name="es")


# In[ ]:


# Notebook cell output: shows the available splits (no effect as a script).
xnli_es


# >joeddav
# >Aug '20
# >
# >@rsk97 In addition, just make sure the model used is trained on an NLI task and that the **last output label corresponds to entailment** while the **first output label corresponds to contradiction**.
# 
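# => XNLI's original `label` column has entailment first (0) and contradiction last (2), so we swap those two ids and store the result in a new `labels` column, which is the column name an `AutoModelForSequenceClassification` expects for its training targets.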

# In[ ]:


# see markdown above
def switch_label_id(row):
    """Map an XNLI row's ``label`` to the zero-shot ordering under ``labels``.

    Ids 0 and 2 are swapped so that contradiction comes first and entailment
    last (see the markdown note above); every other value maps to 1.

    Returns a dict holding only the new ``labels`` key, which ``Dataset.map``
    adds as a column.
    """
    original = row["label"]
    swaps = ((0, 2), (2, 0))
    new_id = next((target for source, target in swaps if original == source), 1)
    return {"labels": new_id}

# Add the swapped `labels` column to every split at once; DatasetDict.map
# applies the function to each split and returns a new DatasetDict.
xnli_es = xnli_es.map(switch_label_id)


# ## Tokenize data

# In[ ]:


# Tokenizer of the BETO checkpoint that `trainable` fine-tunes.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="dccuchile/bert-base-spanish-wwm-cased"
)


# In a first attempt I padded all examples to the maximum length of the dataset (379). However, training takes substantially longer with all that padding, so it is better to pass the tokenizer to the `Trainer` and let the `Trainer` pad each batch dynamically.

# In[ ]:


# Figured out max length of the dataset manually
# max_length = 379
def tokenize(row):
    return tokenizer(row["premise"], row["hypothesis"], truncation=True, max_length=512)  #, padding="max_length", max_length=max_length)


# In[ ]:


# Tokenize every split, dropping the raw text and the original `label`
# column so only model inputs and the swapped `labels` remain.
data = {
    split_name: split_ds.map(
        tokenize,
        batched=True,
        batch_size=128,
        remove_columns=["hypothesis", "premise", "label"],
    )
    for split_name, split_ds in xnli_es.items()
}


# In[ ]:


# Persist the tokenized train/validation splits so `trainable` can reload them
# by path. Absolute paths presumably keep them valid if training runs from a
# different working directory (e.g. a Ray Tune trial) -- TODO confirm.
train_path, valid_path = (
    str(Path(folder).absolute()) for folder in ("./train_ds", "./valid_ds")
)
for split_name, target_dir in (("train", train_path), ("validation", valid_path)):
    data[split_name].save_to_disk(target_dir)


# In[ ]:


# We can use `datasets.Dataset`s directly

# class XnliDataset(torch.utils.data.Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __getitem__(self, idx):
#         item = {key: torch.tensor(val) for key, val in self.data[idx].items()}
#         return item

#     def __len__(self):
#         return len(self.data)


# In[ ]:


def trainable(config):
    """Fine-tune BETO for 3-way NLI on the tokenized XNLI-es splits.

    Written as a function of a single ``config`` dict so it can be called
    directly (as in the next cell) or passed to ``tune.run`` (see the
    commented-out HPO section).

    Args:
        config: Dict with the keys
            ``train_path`` / ``valid_path`` -- directories written by
            ``Dataset.save_to_disk`` holding the tokenized splits;
            ``epochs``, ``warmup_steps``, ``weight_decay``, ``learning_rate``
            -- training hyperparameters (required);
            ``batch_size`` / ``batch_size_eval`` -- per-device train / eval
            batch sizes (optional, default 16 and 64).

    Side effects:
        Writes checkpoints to ``./results`` and logs to ``./logs``. After
        training, the best model (by validation accuracy, reloaded because of
        ``load_best_model_at_end``) is saved with the module-level
        ``tokenizer`` to ``./results`` for use in a zero-shot pipeline.
    """
    metric = load_metric("xnli", "es")

    def compute_metrics(eval_pred):
        # Turn logits into predicted class ids and score them with the XNLI metric.
        predictions, labels = eval_pred
        predictions = predictions.argmax(axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    # Label order produced by `switch_label_id`: contradiction first, entailment
    # last. The zero-shot pipeline looks up the entailment id in
    # `config.label2id`; with the default LABEL_0..LABEL_2 names it cannot find
    # it and only works through a warned fallback to the last id.
    id2label = {0: "contradiction", 1: "neutral", 2: "entailment"}
    model = AutoModelForSequenceClassification.from_pretrained(
        "dccuchile/bert-base-spanish-wwm-cased",
        num_labels=3,
        id2label=id2label,
        label2id={label: label_id for label_id, label in id2label.items()},
    )

    # With load_best_model_at_end, checkpoints must be written at evaluation
    # steps (save_steps a multiple of eval_steps), so tie the two together.
    eval_steps = 500
    training_args = TrainingArguments(
        output_dir='./results',          # checkpoints + final model
        do_train=True,
        do_eval=True,
        evaluation_strategy="steps",
        eval_steps=eval_steps,
        save_steps=eval_steps,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",  # Trainer prefixes compute_metrics keys with "eval_"
        num_train_epochs=config["epochs"],
        # Defaults match the baseline run, so search spaces that omit the
        # batch sizes (like the HPO config in this file) still work.
        per_device_train_batch_size=config.get("batch_size", 16),
        per_device_eval_batch_size=config.get("batch_size_eval", 64),
        warmup_steps=config["warmup_steps"],
        weight_decay=config["weight_decay"],
        learning_rate=config["learning_rate"],
        logging_dir='./logs',
        logging_steps=250,
        save_total_limit=10,
        no_cuda=False,
        disable_tqdm=True,
    )

    # `datasets.Dataset`s can be handed to the Trainer directly.
    train_dataset = load_from_disk(config["train_path"])
    valid_dataset = load_from_disk(config["valid_path"])

    trainer = Trainer(
        model,
        tokenizer=tokenizer,  # lets the Trainer pad each batch dynamically
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    # Persist the best model (and tokenizer) to output_dir so it can be passed
    # straight to a zero-shot-classification pipeline.
    trainer.save_model()


# In[ ]:


# Single training run with fixed baseline hyperparameters (no search).
baseline_config = dict(
    train_path=train_path,
    valid_path=valid_path,
    batch_size=16,
    batch_size_eval=64,
    warmup_steps=500,
    weight_decay=0.001,
    learning_rate=5e-5,
    epochs=3,
)
trainable(baseline_config)


# # HPO

# In[ ]:


# config = {
#     "train_path": train_path,
#     "valid_path": valid_path,
#     "warmup_steps": tune.randint(0, 500),
#     "weight_decay": tune.loguniform(0.00001, 0.1),
#     "learning_rate": tune.loguniform(5e-6, 5e-4),
#     "epochs": tune.choice([2, 3, 4])
# }


# # In[ ]:


# analysis = tune.run(
#     trainable,
#     config=config,
#     metric="eval_acc",
#     mode="max",
#     #search_alg=HyperOptSearch(),
#     #scheduler=ASHAScheduler(),
#     num_samples=1,
# )


# # In[ ]:


# def model_init():
#     return AutoModelForSequenceClassification.from_pretrained("dccuchile/bert-base-spanish-wwm-cased", num_labels=3)

# trainer = Trainer(
#     args=training_args,                  # training arguments, defined above
#     train_dataset=train_dataset,         # training dataset
#     eval_dataset=valid_dataset,          # evaluation dataset
#     model_init=model_init,
#     compute_metrics=compute_metrics,
# )


# # In[ ]:


# best_trial = trainer.hyperparameter_search(
#     direction="maximize",
#     backend="ray",
#     n_trials=2,
#     # Choose among many libraries:
#     # https://docs.ray.io/en/latest/tune/api_docs/suggestion.html
#     search_alg=HyperOptSearch(mode="max", metric="accuracy"),
#     # Choose among schedulers:
#     # https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
#     scheduler=ASHAScheduler(mode="max", metric="accuracy"),
#     local_dir="tune_runs",
# )

